// Copyright 2022 The Google Research Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_DNN_ACCELERATOR_HLS_SYSTEMC_WEIGHTCONTROLLER_H_
#define THIRD_PARTY_DNN_ACCELERATOR_HLS_SYSTEMC_WEIGHTCONTROLLER_H_

#include <mc_connections.h>
#include <systemc.h>

#include "src/AccelTypes.h"

template <typename DTYPE, int NROWS, int NCOLS>
// WeightController: fetches weights from memory, writes them into a two-bank
// weight buffer, and generates the matching buffer read addresses.
//
// Template parameters:
//   DTYPE - element type carried in Pack1D packets.
//   NROWS - innermost reduction tile size (C0); also the packet width of
//           dataResponse / bufferWriteData.
//   NCOLS - burst length of each memory request and the stride applied to the
//           weight (k) index when forming memory addresses.
//
// Four SC_THREADs run concurrently, all clocked on the rising edge of clk and
// reset asynchronously while rstn is low:
//   read_params - broadcasts each incoming Params to the three workers.
//   fetcher     - issues MemoryRequests for weight rows.
//   writer      - pushes returned packets into the current buffer bank.
//   reader      - pushes buffer read addresses for the current bank.
// writer() and reader() alternate between bank 0 and bank 1 (ping-pong),
// toggling after every iteration of level-1 loop 0.
SC_MODULE(WeightController) {
  sc_in<bool> CCS_INIT_S1(clk);
  sc_in<bool> CCS_INIT_S1(rstn);

  // Layer/tiling configuration; one Params per transaction.
  Connections::In<Params> paramsIn;

  // Memory interface: weight read requests out, data packets back.
  Connections::Out<MemoryRequest> addressRequest;
  Connections::In<Pack1D<DTYPE, NROWS> > dataResponse;

  // Two-bank weight buffer interface; index [0]/[1] selects the bank.
  Connections::Out<int> bufferWriteAddress[2];
  Connections::Out<Pack1D<DTYPE, NROWS> > bufferWriteData[2];
  // Number of write address/data pairs that follow for the bank.
  Connections::Out<int> bufferWriteControl[2];
  Connections::Out<int> bufferReadAddress[2];
  // Number of read addresses that follow for the bank.
  Connections::Out<int> bufferReadControl[2];

  // Internal fan-out of paramsIn to the worker threads.
  Connections::Combinational<Params> fetcherParams;
  Connections::Combinational<Params> writerParams;
  Connections::Combinational<Params> readerParams;

  SC_CTOR(WeightController) {
    SC_THREAD(read_params);
    sensitive << clk.pos();
    async_reset_signal_is(rstn, false);

    SC_THREAD(fetcher);
    sensitive << clk.pos();
    async_reset_signal_is(rstn, false);

    SC_THREAD(reader);
    sensitive << clk.pos();
    async_reset_signal_is(rstn, false);

    SC_THREAD(writer);
    sensitive << clk.pos();
    async_reset_signal_is(rstn, false);
  }

  // Issues one NCOLS-word burst request per weight row. Iterates level-0 loops
  // 0..2, all six level-1 loops, and c0 in [0, NROWS). The level-1 input X/Y
  // loops are clamped to a single iteration because the weight address does
  // not depend on them.
  void fetcher() {
    fetcherParams.ResetRead();
    addressRequest.Reset();

    wait();

    while (true) {
      Params params = fetcherParams.Pop();

      // Zero-initialised: level-0 counters 3..5 are never iterated below, but
      // loop_counters[0][params.weightLoopIndex[0]] may still read one of them.
      int loop_counters[2][6] = {};
      int loop_bounds[2][6];

#pragma hls_unroll yes
      for (int i = 0; i < 2; i++) {
        for (int j = 0; j < 6; j++) {
          loop_bounds[i][j] = params.loops[i][j];
        }
      }

      // Weight addresses are independent of the input X/Y position.
      loop_bounds[1][params.inputXLoopIndex[1]] = 1;
      loop_bounds[1][params.inputYLoopIndex[1]] = 1;

      const int C0 = NROWS;

#pragma hls_pipeline_init_interval 1
#pragma hls_pipeline_stall_mode flush
      for (loop_counters[0][0] = 0; loop_counters[0][0] < loop_bounds[0][0];
           loop_counters[0][0]++) {
        for (loop_counters[0][1] = 0; loop_counters[0][1] < loop_bounds[0][1];
             loop_counters[0][1]++) {
          for (loop_counters[0][2] = 0; loop_counters[0][2] < loop_bounds[0][2];
               loop_counters[0][2]++) {
            for (loop_counters[1][0] = 0;
                 loop_counters[1][0] < loop_bounds[1][0];
                 loop_counters[1][0]++) {
              for (loop_counters[1][1] = 0;
                   loop_counters[1][1] < loop_bounds[1][1];
                   loop_counters[1][1]++) {
                for (loop_counters[1][2] = 0;
                     loop_counters[1][2] < loop_bounds[1][2];
                     loop_counters[1][2]++) {
                  for (loop_counters[1][3] = 0;
                       loop_counters[1][3] < loop_bounds[1][3];
                       loop_counters[1][3]++) {
                    for (loop_counters[1][4] = 0;
                         loop_counters[1][4] < loop_bounds[1][4];
                         loop_counters[1][4]++) {
                      for (loop_counters[1][5] = 0;
                           loop_counters[1][5] < loop_bounds[1][5];
                           loop_counters[1][5]++) {
                        for (int c0 = 0; c0 < C0; c0++) {
                          // Output-channel (weight) tile indices per level.
                          int k2 = loop_counters[0][params.weightLoopIndex[0]];
                          int K2 = params.loops[0][params.weightLoopIndex[0]];
                          int k1 = loop_counters[1][params.weightLoopIndex[1]];
                          int K1 = params.loops[1][params.weightLoopIndex[1]];
                          // Level-1 reduction (input-channel) tile index.
                          int c1 =
                              loop_counters[1][params.reductionLoopIndex[1]];
                          int C1 =
                              params.loops[1][params.reductionLoopIndex[1]];
                          int fx = loop_counters[1][params.fxIndex];
                          int FX = loop_bounds[1][params.fxIndex];
                          int fy = loop_counters[1][params.fyIndex];

                          int c = c1 * C0 + c0;
                          int C = C1 * C0;
                          int k = k2 * K1 * NCOLS + k1 * NCOLS;
                          int K = K2 * K1 * NCOLS;

                          // Row-major (fy, fx, c, k) offset into the weights.
                          int baseAddress =
                              (fy * FX * C * K) + (fx * C * K) + (c * K) + k;

                          int burstSize = NCOLS;
                          MemoryRequest memRequest = {
                              params.WEIGHT_OFFSET + baseAddress, burstSize};
                          addressRequest.Push(memRequest);
                        }
                      }
                    }
                  }
                }
              }
            }
          }
        }
      }
    }
  }

  // Pops one packet per iteration of the same loop nest fetcher() uses (same
  // clamped bounds). Each packet is pushed, with its buffer address, to the
  // current bank. Before every level-1 loop-0 iteration the number of writes
  // for that bank is announced on bufferWriteControl; the bank is toggled
  // after the iteration completes.
  void writer() {
    writerParams.ResetRead();

    dataResponse.Reset();

    bufferWriteControl[0].Reset();
    bufferWriteControl[1].Reset();

    bufferWriteAddress[0].Reset();
    bufferWriteAddress[1].Reset();

    bufferWriteData[0].Reset();
    bufferWriteData[1].Reset();

    wait();
    while (true) {
      Params params = writerParams.Pop();

      bool bankSel = false;

      int loop_counters[2][6] = {};
      int loop_bounds[2][6];

#pragma hls_unroll yes
      for (int i = 0; i < 2; i++) {
        for (int j = 0; j < 6; j++) {
          loop_bounds[i][j] = params.loops[i][j];
        }
      }

      // Same clamping as fetcher(), so one response is consumed per request
      // issued (assuming memory returns one packet per request).
      loop_bounds[1][params.inputXLoopIndex[1]] = 1;
      loop_bounds[1][params.inputYLoopIndex[1]] = 1;

      const int C0 = NROWS;

#pragma hls_pipeline_init_interval 1
#pragma hls_pipeline_stall_mode flush
      // outer memory hierarchy
      for (loop_counters[0][0] = 0; loop_counters[0][0] < loop_bounds[0][0];
           loop_counters[0][0]++) {
        for (loop_counters[0][1] = 0; loop_counters[0][1] < loop_bounds[0][1];
             loop_counters[0][1]++) {
          for (loop_counters[0][2] = 0;
               loop_counters[0][2] < loop_bounds[0][2];
               loop_counters[0][2]++) {
            for (loop_counters[1][0] = 0;
                 loop_counters[1][0] < loop_bounds[1][0];
                 loop_counters[1][0]++) {
              // Exactly this many address/data pairs are pushed below before
              // the bank is toggled.
              bufferWriteControl[bankSel].Push(
                  loop_bounds[1][1] * loop_bounds[1][2] * loop_bounds[1][3] *
                  loop_bounds[1][4] * loop_bounds[1][5] * NROWS);

              for (loop_counters[1][1] = 0;
                   loop_counters[1][1] < loop_bounds[1][1];
                   loop_counters[1][1]++) {
                for (loop_counters[1][2] = 0;
                     loop_counters[1][2] < loop_bounds[1][2];
                     loop_counters[1][2]++) {
                  for (loop_counters[1][3] = 0;
                       loop_counters[1][3] < loop_bounds[1][3];
                       loop_counters[1][3]++) {
                    for (loop_counters[1][4] = 0;
                         loop_counters[1][4] < loop_bounds[1][4];
                         loop_counters[1][4]++) {
                      for (loop_counters[1][5] = 0;
                           loop_counters[1][5] < loop_bounds[1][5];
                           loop_counters[1][5]++) {
                        for (int c0 = 0; c0 < C0; c0++) {
                          int k1 = loop_counters[1][params.weightLoopIndex[1]];
                          int K1 = params.loops[1][params.weightLoopIndex[1]];
                          int c1 =
                              loop_counters[1][params.reductionLoopIndex[1]];
                          int C1 =
                              params.loops[1][params.reductionLoopIndex[1]];
                          int fx = loop_counters[1][params.fxIndex];
                          int FX = loop_bounds[1][params.fxIndex];
                          int fy = loop_counters[1][params.fyIndex];

                          int c = c1 * C0 + c0;
                          int C = C1 * C0;

                          // Same packet type as dataResponse and
                          // bufferWriteData.
                          // NOTE(review): fetcher() requests NCOLS-word bursts
                          // while these packets are NROWS wide; the two agree
                          // only when NROWS == NCOLS — confirm.
                          Pack1D<DTYPE, NROWS> data = dataResponse.Pop();

                          int address = (fy * FX * C * K1) + (fx * C * K1) +
                                        (c * K1) + k1;

                          bufferWriteAddress[bankSel].Push(address);
                          bufferWriteData[bankSel].Push(data);
                        }
                      }
                    }
                  }
                }
              }
              bankSel = !bankSel;
            }
          }
        }
      }
    }
  }

  // Generates buffer read addresses for the current bank. Before every level-1
  // loop-0 iteration the number of reads for that bank is announced on
  // bufferReadControl; the bank is toggled after the iteration completes.
  // Within each weight row group, c0 is walked from NROWS-1 down to 0.
  void reader() {
    readerParams.ResetRead();

    bufferReadControl[0].Reset();
    bufferReadControl[1].Reset();
    bufferReadAddress[0].Reset();
    bufferReadAddress[1].Reset();

    wait();
    while (true) {
      Params params = readerParams.Pop();

      bool bankSel = false;

      int loop_counters[2][6] = {};
      int loop_bounds[2][6];

#pragma hls_unroll yes
      for (int i = 0; i < 2; i++) {
        for (int j = 0; j < 6; j++) {
          loop_bounds[i][j] = params.loops[i][j];
        }
      }

      // Weights are not re-read across the level-1 weight-reuse loop.
      // NOTE(review): only one level-1 loop is clamped here, whereas fetcher()
      // and writer() clamp two (input X and Y). Every other *Index[k] field in
      // this module selects a level-k loop, so weightReuseIndex[0] is presumed
      // to be a level-0 index and is not applied to loop_bounds[1]. Confirm
      // that reuse across a single level-1 loop is intended.
      loop_bounds[1][params.weightReuseIndex[1]] = 1;

      const int C0 = NROWS;

#pragma hls_pipeline_init_interval 1
#pragma hls_pipeline_stall_mode flush
      // outer memory hierarchy
      for (loop_counters[0][0] = 0; loop_counters[0][0] < loop_bounds[0][0];
           loop_counters[0][0]++) {
        for (loop_counters[0][1] = 0; loop_counters[0][1] < loop_bounds[0][1];
             loop_counters[0][1]++) {
          for (loop_counters[0][2] = 0;
               loop_counters[0][2] < loop_bounds[0][2];
               loop_counters[0][2]++) {
            for (loop_counters[1][0] = 0;
                 loop_counters[1][0] < loop_bounds[1][0];
                 loop_counters[1][0]++) {
              // Exactly this many addresses are pushed below before the bank
              // is toggled.
              bufferReadControl[bankSel].Push(
                  loop_bounds[1][1] * loop_bounds[1][2] * loop_bounds[1][3] *
                  loop_bounds[1][4] * loop_bounds[1][5] * NROWS);

              for (loop_counters[1][1] = 0;
                   loop_counters[1][1] < loop_bounds[1][1];
                   loop_counters[1][1]++) {
                for (loop_counters[1][2] = 0;
                     loop_counters[1][2] < loop_bounds[1][2];
                     loop_counters[1][2]++) {
                  for (loop_counters[1][3] = 0;
                       loop_counters[1][3] < loop_bounds[1][3];
                       loop_counters[1][3]++) {
                    for (loop_counters[1][4] = 0;
                         loop_counters[1][4] < loop_bounds[1][4];
                         loop_counters[1][4]++) {
                      for (loop_counters[1][5] = 0;
                           loop_counters[1][5] < loop_bounds[1][5];
                           loop_counters[1][5]++) {
                        // loaded in reverse order (presumably the order the
                        // consumer expects rows in — confirm)
                        for (int c0 = C0 - 1; c0 >= 0; c0--) {
                          int k1 = loop_counters[1][params.weightLoopIndex[1]];
                          int K1 = params.loops[1][params.weightLoopIndex[1]];
                          int fx = loop_counters[1][params.fxIndex];
                          int FX = loop_bounds[1][params.fxIndex];
                          int fy = loop_counters[1][params.fyIndex];

                          // NOTE(review): unlike writer(), this address does
                          // not include the level-1 reduction counter c1. It
                          // uses C0/c0 where writer() uses C = C1*C0 and
                          // c = c1*C0 + c0. The two agree only when that
                          // loop's bound is 1 — confirm.
                          int address = (fy * FX * C0 * K1) + (fx * C0 * K1) +
                                        (c0 * K1) + k1;

                          bufferReadAddress[bankSel].Push(address);
                        }
                      }
                    }
                  }
                }
              }
              bankSel = !bankSel;
            }
          }
        }
      }
    }
  }

  // Pops each Params from paramsIn and forwards a copy to the fetcher, writer
  // and reader threads, in that order.
  void read_params() {
    paramsIn.Reset();
    fetcherParams.ResetWrite();
    writerParams.ResetWrite();
    readerParams.ResetWrite();

    wait();

    while (true) {
      Params params = paramsIn.Pop();

      fetcherParams.Push(params);
      writerParams.Push(params);
      readerParams.Push(params);
    }
  }
};

#endif  // THIRD_PARTY_DNN_ACCELERATOR_HLS_SYSTEMC_WEIGHTCONTROLLER_H_
